import torch
import copy
from Network.General.Conv.conv import ConvNetwork
from Network.General.Flat.mlp import MLPNetwork
from Network.network import Network
from Network.network_utils import pytorch_model

class KeyQueryEncoder(Network):
    """Convert a flat input (and flat mask) into a collection of keys and queries.

    The last axis of the input is cut into per-object chunks using the sizes
    stored on ``args.factor``. When ``key_args.output_dim > 0`` the key chunks
    and query chunks are each passed through their own ``ConvNetwork``;
    otherwise the raw chunks are returned unchanged.
    """

    def __init__(self, args, key_args, query_args):
        """Store the factorization parameters and build the optional encoders.

        Args:
            args: network arguments; ``args.factor`` and ``args.embed_dim``
                are read here and ``args`` is forwarded to ``Network``.
            key_args: arguments used to construct the key ``ConvNetwork``.
            query_args: arguments used to construct the query ``ConvNetwork``.
        """
        super().__init__(args)
        self.fp = args.factor
        self.output_dim = key_args.output_dim # should equal query_args.output_dim
        use_embedding = args.embed_dim > 0
        self.key_dim = args.embed_dim if use_embedding else args.factor.single_obj_dim
        self.query_dim = args.embed_dim if use_embedding else args.factor.object_dim
        # when > 0, slice_masks replaces zero mask entries with this value
        self.soft_mask_param = 0
        if key_args.output_dim > 0:
            self.key_encoder = ConvNetwork(key_args)
            self.query_encoder = ConvNetwork(query_args)
            self.model = [self.key_encoder, self.query_encoder]
        else: # no encoders: this module only slices the input
            self.key_encoder, self.query_encoder = None, None
            self.model = list()
        self.train()
        self.reset_network_parameters()

    def slice_input(self, x):
        """Split the last axis of ``x`` into stacked key and query chunks.

        Assumes the last axis is laid out as
            [first_obj_dim + object_dim * num_objects]
        where first_obj_dim = single_obj_dim * num_keys. Leading axes are
        preserved. Trailing values that do not fill a whole ``object_dim``
        chunk are dropped.
        TODO: handle post_dim component

        Args:
            x: tensor whose last axis holds the flat state.

        Returns:
            (keys, queries): keys is [..., num_keys, single_obj_dim] and
            queries is [..., num_queries, object_dim].
        """
        fp = self.fp
        # the "not in" check is for backward compatibility with factor
        # parameters that predate start_dim
        if "start_dim" not in fp or fp.start_dim < 0:
            start_dim = fp.first_obj_dim
        else:
            start_dim = fp.start_dim
        key_width, query_width = fp.single_obj_dim, fp.object_dim
        num_keys = int(fp.first_obj_dim // key_width)
        num_queries = int((x.shape[-1] - start_dim) // query_width)
        key_chunks = [x[..., i * key_width:(i + 1) * key_width] for i in range(num_keys)]
        query_chunks = [
            x[..., start_dim + j * query_width:start_dim + (j + 1) * query_width]
            for j in range(num_queries)
        ]
        keys = torch.stack(key_chunks, dim=-2)
        queries = torch.stack(query_chunks, dim=-2)
        return keys, queries

    def slice_masks(self, m, batch_size, num_keys, num_queries):
        """Bring a mask to the shape [batch_size, num_keys, num_queries].

        Handling by rank of ``m`` (``fp.name_idx`` selection is applied first
        whenever ``fp.name_idx != -1``):
          * ``None``: returned as is, since a missing mask has special meaning.
          * rank 1: treated as one row of per-query values and broadcast over
            batch and keys.
          * rank 2: if the last axis has ``num_queries`` entries it is
            broadcast over keys, otherwise it is reshaped from
            [batch, num_keys * num_queries].
          * rank 3 or more: returned after the optional ``name_idx`` selection
            on the last two axes.

        If ``self.soft_mask_param > 0``, entries equal to 0 are replaced with
        ``self.soft_mask_param``. The replacement is done on a copy so the
        caller's mask tensor is never modified.

        Args:
            m: mask tensor or None.
            batch_size: size of the leading axis of the result.
            num_keys: size of the key axis of the result.
            num_queries: size of the query axis of the result.

        Returns:
            The reshaped/broadcast mask, or None if ``m`` is None.
        """
        if m is None:
            return m
        if self.soft_mask_param > 0:
            # clone first: the in-place write must not leak into the caller's tensor
            m = m.clone()
            m[m == 0] = self.soft_mask_param
        name_idx = self.fp.name_idx
        rank = len(m.shape)
        if rank == 1:
            if name_idx != -1: m = m[name_idx]
            m = m.unsqueeze(0).unsqueeze(1)
            m = m.broadcast_to(batch_size, num_keys, m.shape[-1])
        elif rank == 2:
            if name_idx != -1: m = m[..., name_idx]
            if m.shape[-1] == num_queries:
                m = m.unsqueeze(1).broadcast_to(batch_size, num_keys, m.shape[-1])
            else:
                m = m.reshape(batch_size, num_keys, num_queries)
        elif name_idx != -1:
            m = m[..., name_idx, name_idx]
        return m

    def forward(self, x):
        """Slice ``x`` into keys and queries and encode them if encoders exist.

        Args:
            x: flat input; wrapped with ``pytorch_model.wrap`` before slicing.

        Returns:
            (keys, queries). With no encoders (``output_dim <= 0``) these are
            the raw slices from ``slice_input``; otherwise the encoder outputs,
            presumably [batch, num_keys/queries, embed_dim] (depends on
            ConvNetwork, not visible here).
        """
        x = pytorch_model.wrap(x, cuda=self.iscuda)
        raw_keys, raw_queries = self.slice_input(x)
        # mirror the constructor, which only builds encoders for output_dim > 0
        if self.output_dim <= 0:
            return raw_keys, raw_queries
        # the encoders are applied with the last two axes swapped, then swapped back
        keys = self.key_encoder(raw_keys.transpose(-1, -2)).transpose(-2, -1)
        queries = self.query_encoder(raw_queries.transpose(-1, -2)).transpose(-2, -1)
        return keys, queries
